{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.00114815\n",
      "0.000736224\n",
      "0.00113975\n",
      "3.69 s ± 161 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
      "The slowest run took 4.78 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
      "1.05 s ± 465 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
      "18.5 ms ± 10.8 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
     ]
    }
   ],
   "source": [
    "import numba as nb\n",
    "import numpy as np\n",
    "\n",
    "def conv_kernel(x, w, rs, n, n_channels, height, width, n_filters, filter_height, filter_width, out_h, out_w):\n",
    "    for i in range(n):\n",
    "        for j in range(out_h):\n",
    "            for p in range(out_w):\n",
    "                window = x[i, ..., j:j+filter_height, p:p+filter_width]\n",
    "                for q in range(n_filters):\n",
    "                    rs[i, q, j, p] += np.sum(w[q] * window)\n",
    "\n",
    "@nb.jit(nopython=True)\n",
    "def jit_conv_kernel(x, w, rs, n, n_channels, height, width, n_filters, filter_height, filter_width, out_h, out_w):\n",
    "    \"\"\"Numba-compiled version of the slice-based convolution kernel.\n",
    "\n",
    "    Adds, in place, the sum of each filter times the matching input window to\n",
    "    ``rs[sample, filter, row, col]``.  Nothing is returned.\n",
    "    ``n_channels``, ``height`` and ``width`` are accepted but not used here.\n",
    "    \"\"\"\n",
    "    for sample in range(n):\n",
    "        for row in range(out_h):\n",
    "            for col in range(out_w):\n",
    "                patch = x[sample, ..., row:row+filter_height, col:col+filter_width]\n",
    "                for filt in range(n_filters):\n",
    "                    rs[sample, filt, row, col] += np.sum(w[filt] * patch)\n",
    "\n",
    "def conv(x, w, kernel, args):\n",
    "    n, n_filters = args[0], args[4]\n",
    "    out_h, out_w = args[-2:]\n",
    "    rs = np.zeros([n, n_filters, out_h, out_w], dtype=np.float32)\n",
    "    kernel(x, w, rs, *args)\n",
    "    return rs\n",
    "\n",
    "def cs231n_conv(x, w, args):\n",
    "    n, n_channels, height, width, n_filters, filter_height, filter_width, out_h, out_w = args\n",
    "    shape = (n_channels, filter_height, filter_width, n, out_h, out_w)\n",
    "    strides = (height * width, width, 1, n_channels * height * width, width, 1)\n",
    "    strides = x.itemsize * np.asarray(strides)\n",
    "    x_cols = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides).reshape(\n",
    "        n_channels * filter_height * filter_width, n * out_h * out_w)\n",
    "    return w.reshape(n_filters, -1).dot(x_cols).reshape(n_filters, n, out_h, out_w).transpose(1, 0, 2, 3)\n",
    "\n",
    "# 64 个 3 x 28 x 28 的图像输入（模拟 mnist）\n",
    "x = np.random.randn(64, 3, 28, 28).astype(np.float32)\n",
    "# 16 个 5 x 5 的 kernel\n",
    "w = np.random.randn(16, x.shape[1], 5, 5).astype(np.float32)\n",
    "\n",
    "n, n_channels, height, width = x.shape\n",
    "n_filters, _, filter_height, filter_width = w.shape\n",
    "out_h = height - filter_height + 1\n",
    "out_w = width - filter_width + 1\n",
    "args = (n, n_channels, height, width, n_filters, filter_height, filter_width, out_h, out_w)\n",
    "\n",
    "# Print the L2 norm of the difference between the outputs of each pair of\n",
    "# implementations; values close to zero mean the implementations agree.\n",
    "print(np.linalg.norm((cs231n_conv(x, w, args) - conv(x, w, conv_kernel, args)).ravel()))\n",
    "print(np.linalg.norm((cs231n_conv(x, w, args) - conv(x, w, jit_conv_kernel, args)).ravel()))\n",
    "print(np.linalg.norm((conv(x, w, conv_kernel, args) - conv(x, w, jit_conv_kernel, args)).ravel()))\n",
    "# Time the pure-NumPy loop, the Numba-compiled loop and the as_strided version\n",
    "# on the same inputs.\n",
    "%timeit conv(x, w, conv_kernel, args)\n",
    "%timeit conv(x, w, jit_conv_kernel, args)\n",
    "%timeit cs231n_conv(x, w, args)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "+ 注意：这里如果使用`np.allclose`的话会过不了`assert`；事实上，仅仅是将数组的`dtype`从`float64`变成`float32`、精度就会下降很多，毕竟卷积涉及到的运算太多"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "288 ms ± 6.91 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
      "71.2 ms ± 5.44 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
      "8.7 ms ± 62.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
     ]
    }
   ],
   "source": [
    "@nb.jit(nopython=True)\n",
    "def jit_conv_kernel2(x, w, rs, n, n_channels, height, width, n_filters, filter_height, filter_width, out_h, out_w):\n",
    "    \"\"\"Numba-compiled convolution kernel written as fully explicit scalar loops.\n",
    "\n",
    "    Each product ``x[sample, chan, row+dy, col+dx] * w[filt, chan, dy, dx]`` is\n",
    "    added in place to ``rs[sample, filt, row, col]``.  Nothing is returned.\n",
    "    ``height`` and ``width`` are accepted but not used here.\n",
    "    \"\"\"\n",
    "    for sample in range(n):\n",
    "        for row in range(out_h):\n",
    "            for col in range(out_w):\n",
    "                for filt in range(n_filters):\n",
    "                    for chan in range(n_channels):\n",
    "                        for dy in range(filter_height):\n",
    "                            for dx in range(filter_width):\n",
    "                                rs[sample, filt, row, col] += x[sample, chan, row+dy, col+dx] * w[filt, chan, dy, dx]\n",
    "                                \n",
    "# Check the explicit-loop kernel against the slice-based JIT kernel (the check\n",
    "# previously compared jit_conv_kernel with itself, so it could never fail).\n",
    "# Explicit tolerances are given because both kernels accumulate in float32 and\n",
    "# may differ by rounding.\n",
    "assert np.allclose(conv(x, w, jit_conv_kernel, args), conv(x, w, jit_conv_kernel2, args), rtol=1e-4, atol=1e-4)\n",
    "# Time the slice-based JIT kernel, the explicit-loop JIT kernel and the\n",
    "# as_strided version on the same inputs.\n",
    "%timeit conv(x, w, jit_conv_kernel, args)\n",
    "%timeit conv(x, w, jit_conv_kernel2, args)\n",
    "%timeit cs231n_conv(x, w, args)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "+ 可以看到，使用`jit`和使用纯`numpy`进行编程的很大一点不同就是，不要畏惧用`for`；事实上一般来说，代码“长得越像 C”、速度就会越快"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "696 ms ± 56.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
      "8.68 ms ± 92.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n",
      "1.54 ms ± 59.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
     ]
    }
   ],
   "source": [
    "def max_pool_kernel(x, rs, *args):\n",
    "    n, n_channels, pool_height, pool_width, out_h, out_w = args\n",
    "    for i in range(n):\n",
    "        for j in range(n_channels):\n",
    "            for p in range(out_h):\n",
    "                for q in range(out_w):\n",
    "                    window = x[i, j, p:p+pool_height, q:q+pool_width]\n",
    "                    rs[i, j, p, q] += np.max(window)\n",
    "\n",
    "@nb.jit(nopython=True)\n",
    "def jit_max_pool_kernel(x, rs, *args):\n",
    "    \"\"\"Numba-compiled version of the slice-based max-pooling kernel.\n",
    "\n",
    "    ``args`` must unpack to ``(n, n_channels, pool_height, pool_width, out_h,\n",
    "    out_w)``.  The maximum of every pooling window is added in place to\n",
    "    ``rs[sample, channel, row, col]``.  Nothing is returned.\n",
    "    \"\"\"\n",
    "    batch, channels, pool_h, pool_w, rows, cols = args\n",
    "    for sample in range(batch):\n",
    "        for chan in range(channels):\n",
    "            for row in range(rows):\n",
    "                for col in range(cols):\n",
    "                    patch = x[sample, chan, row:row+pool_h, col:col+pool_w]\n",
    "                    rs[sample, chan, row, col] += np.max(patch)\n",
    "                    \n",
    "@nb.jit(nopython=True)\n",
    "def jit_max_pool_kernel2(x, rs, *args):\n",
    "    \"\"\"Numba-compiled max-pooling kernel written as fully explicit scalar loops.\n",
    "\n",
    "    ``args`` must unpack to ``(n, n_channels, pool_height, pool_width, out_h,\n",
    "    out_w)``.  For every window the running maximum is found element by\n",
    "    element and then added in place to ``rs[sample, channel, row, col]``.\n",
    "    Nothing is returned.\n",
    "    \"\"\"\n",
    "    batch, channels, pool_h, pool_w, rows, cols = args\n",
    "    for sample in range(batch):\n",
    "        for chan in range(channels):\n",
    "            for row in range(rows):\n",
    "                for col in range(cols):\n",
    "                    # Seed the running maximum with the window's top-left element.\n",
    "                    best = x[sample, chan, row, col]\n",
    "                    for dy in range(pool_h):\n",
    "                        for dx in range(pool_w):\n",
    "                            candidate = x[sample, chan, row+dy, col+dx]\n",
    "                            if candidate > best:\n",
    "                                best = candidate\n",
    "                    rs[sample, chan, row, col] += best\n",
    "\n",
    "def max_pool(x, kernel, args):\n",
    "    n, n_channels = args[:2]\n",
    "    out_h, out_w = args[-2:]\n",
    "    rs = np.zeros([n, n_filters, out_h, out_w], dtype=np.float32)\n",
    "    kernel(x, rs, *args)\n",
    "    return rs\n",
    "\n",
    "# 2 x 2 pooling window, moved one element at a time over the input batch `x`.\n",
    "pool_height, pool_width = 2, 2\n",
    "n, n_channels, height, width = x.shape\n",
    "# NOTE: `out_h`, `out_w` and `args` are rebound here; the convolution cells use\n",
    "# the same names with different values, so re-run their setup before re-running them.\n",
    "out_h = height - pool_height + 1\n",
    "out_w = width - pool_width + 1\n",
    "args = (n, n_channels, pool_height, pool_width, out_h, out_w)\n",
    "\n",
    "# The three kernels must produce matching outputs before they are timed.\n",
    "assert np.allclose(max_pool(x, max_pool_kernel, args), max_pool(x, jit_max_pool_kernel, args))\n",
    "assert np.allclose(max_pool(x, jit_max_pool_kernel, args), max_pool(x, jit_max_pool_kernel2, args))\n",
    "# Time the pure-NumPy kernel, the slice-based JIT kernel and the explicit-loop JIT kernel.\n",
    "%timeit max_pool(x, max_pool_kernel, args)\n",
    "%timeit max_pool(x, jit_max_pool_kernel, args)\n",
    "%timeit max_pool(x, jit_max_pool_kernel2, args)"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [conda env:playground]",
   "language": "python",
   "name": "conda-env-playground-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
